<!DOCTYPE HTML>
<html lang="zh-CN">


<head>
    <!-- Document metadata: encoding, SEO keywords/description, and browser/mobile rendering hints. -->
    <meta charset="utf-8">
    <meta name="keywords" content="Pytorch入门, J Sir">
    <meta name="description" content="">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <!-- Pinch zoom stays enabled on purpose: user-scalable=no prevents low-vision readers from enlarging the text. -->
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta name="renderer" content="webkit|ie-stand|ie-comp">
    <meta name="mobile-web-app-capable" content="yes">
    <meta name="format-detection" content="telephone=no">
    <meta name="apple-mobile-web-app-capable" content="yes">
    <meta name="apple-mobile-web-app-status-bar-style" content="black-translucent">
    <!-- Global site tag (gtag.js) - Google Analytics -->


    <title>Pytorch入门 | J Sir</title>
    <link rel="icon" type="image/png" href="/favicon.png">

    <!-- Theme and vendor stylesheets; /css/my.css is loaded last so its rules can override the theme. -->
    <link rel="stylesheet" type="text/css" href="/libs/awesome/css/all.css">
    <link rel="stylesheet" type="text/css" href="/libs/materialize/materialize.min.css">
    <link rel="stylesheet" type="text/css" href="/libs/aos/aos.css">
    <link rel="stylesheet" type="text/css" href="/libs/animate/animate.min.css">
    <link rel="stylesheet" type="text/css" href="/libs/lightGallery/css/lightgallery.min.css">
    <link rel="stylesheet" type="text/css" href="/css/matery.css">
    <link rel="stylesheet" type="text/css" href="/css/my.css">

    <!-- Loaded synchronously in the head; presumably later scripts on the page expect jQuery to exist already. -->
    <script src="/libs/jquery/jquery.min.js"></script>

<meta name="generator" content="Hexo 6.0.0">
<style>.github-emoji { position: relative; display: inline-block; width: 1.2em; min-height: 1.2em; overflow: hidden; vertical-align: top; color: transparent; }  .github-emoji > span { position: relative; z-index: 10; }  .github-emoji img, .github-emoji .fancybox { margin: 0 !important; padding: 0 !important; border: none !important; outline: none !important; text-decoration: none !important; user-select: none !important; cursor: auto !important; }  .github-emoji img { height: 1.2em !important; width: 1.2em !important; position: absolute !important; left: 50% !important; top: 50% !important; transform: translate(-50%, -50%) !important; user-select: none !important; cursor: auto !important; } .github-emoji-fallback { color: inherit; } .github-emoji-fallback img { opacity: 0 !important; }</style>
</head>




<body>
    <header class="navbar-fixed">
    <nav id="headNav" class="bg-color nav-transparent">
        <div id="navContainer" class="nav-wrapper container">
            <div class="brand-logo">
                <a href="/" class="waves-effect waves-light">
                    
                    <img src="/medias/logo.png" class="logo-img" alt="LOGO">
                    
                    <span class="logo-span">J Sir</span>
                </a>
            </div>
            

<!-- Hamburger trigger for the #mobile-nav drawer. It is icon-only, so aria-label supplies the accessible name. -->
<a href="#" data-target="mobile-nav" class="sidenav-trigger button-collapse" aria-label="打开导航菜单"><i class="fas fa-bars" aria-hidden="true"></i></a>
<!-- Desktop navigation menu. Every item except search carries the hide-on-med-and-down class. -->
<!-- Icons are decorative (each link has a text label), so they are hidden from assistive technology. -->
<ul class="right nav-menu">
  <li class="hide-on-med-and-down nav-item">
    <a href="/" class="waves-effect waves-light">
      <i class="fas fa-home" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>首页</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/tags" class="waves-effect waves-light">
      <i class="fas fa-tags" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>标签</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/categories" class="waves-effect waves-light">
      <i class="fas fa-bookmark" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>分类</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/archives" class="waves-effect waves-light">
      <i class="fas fa-archive" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>归档</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/about" class="waves-effect waves-light">
      <i class="fas fa-user-circle" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>关于</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/contact" class="waves-effect waves-light">
      <i class="fas fa-comments" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>留言板</span>
    </a>
  </li>
  <li class="hide-on-med-and-down nav-item">
    <a href="/friends" class="waves-effect waves-light">
      <i class="fas fa-address-book" style="zoom: 0.6;" aria-hidden="true"></i>
      <span>友情链接</span>
    </a>
  </li>
  <li>
    <!-- Icon-only search link: aria-label names it, because a title on the icon is not reliably announced. -->
    <a href="#searchModal" class="modal-trigger waves-effect waves-light" aria-label="搜索">
      <i id="searchIcon" class="fas fa-search" title="搜索" style="zoom: 0.85;" aria-hidden="true"></i>
    </a>
  </li>
</ul>


<!-- Slide-out navigation drawer, targeted by the sidenav trigger through data-target="mobile-nav". -->
<div id="mobile-nav" class="side-nav sidenav">

    <div class="mobile-head bg-color">
        <!-- Decorative logo: the site name follows as text, so the image gets an empty alt. -->
        <img src="/medias/logo.png" class="logo-img circle responsive-img" alt="">
        <div class="logo-name">J Sir</div>
        <div class="logo-desc">
            Never really desperate, only the lost of the soul.
        </div>
    </div>

    <ul class="menu-list mobile-menu-list">
        <li class="m-nav-item">
            <a href="/" class="waves-effect waves-light">
                <i class="fa-fw fas fa-home" aria-hidden="true"></i>
                首页
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/tags" class="waves-effect waves-light">
                <i class="fa-fw fas fa-tags" aria-hidden="true"></i>
                标签
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/categories" class="waves-effect waves-light">
                <i class="fa-fw fas fa-bookmark" aria-hidden="true"></i>
                分类
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/archives" class="waves-effect waves-light">
                <i class="fa-fw fas fa-archive" aria-hidden="true"></i>
                归档
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/about" class="waves-effect waves-light">
                <i class="fa-fw fas fa-user-circle" aria-hidden="true"></i>
                关于
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/contact" class="waves-effect waves-light">
                <i class="fa-fw fas fa-comments" aria-hidden="true"></i>
                留言板
            </a>
        </li>
        <li class="m-nav-item">
            <a href="/friends" class="waves-effect waves-light">
                <i class="fa-fw fas fa-address-book" aria-hidden="true"></i>
                友情链接
            </a>
        </li>

        <li><div class="divider"></div></li>
        <li>
            <!-- Opens in a new tab; rel keeps the external page from reaching window.opener. -->
            <a href="https://github.com/jy741" class="waves-effect waves-light" target="_blank" rel="noopener noreferrer">
                <i class="fab fa-github-square fa-fw" aria-hidden="true"></i>Fork Me
            </a>
        </li>
    </ul>
</div>


        </div>

        
            <style>
    /* Styles for the "Fork Me" GitHub corner ribbon rendered inside the navbar. */

    /* Hide the ribbon whenever an ancestor has the nav-transparent class (set on #headNav in this page). */
    .nav-transparent .github-corner {
        display: none !important;
    }

    /* Pin the ribbon to the top-right corner and enlarge it slightly. */
    .github-corner {
        position: absolute;
        z-index: 10;
        top: 0;
        right: 0;
        border: 0;
        transform: scale(1.1);
    }

    /* `color` feeds the paths drawn with fill="currentColor"; `fill` paints the remaining path. */
    .github-corner svg {
        color: #0f9d58;
        fill: #fff;
        height: 64px;
        width: 64px;
    }

    /* Play the arm animation once per hover. */
    .github-corner:hover .octo-arm {
        animation: a 0.56s ease-in-out;
    }

    /* No animation while the ribbon is not hovered. */
    .github-corner .octo-arm {
        animation: none;
    }

    /* Arm rotation sequence: rest, -25deg, 10deg, -25deg, 10deg, back to rest. */
    @keyframes a {
        0%,
        to {
            transform: rotate(0);
        }
        20%,
        60% {
            transform: rotate(-25deg);
        }
        40%,
        80% {
            transform: rotate(10deg);
        }
    }
</style>

<!-- "Fork Me" corner ribbon. The SVG is hidden from assistive technology, so aria-label names the link. -->
<!-- rel="noopener noreferrer" keeps the page opened in the new tab from reaching window.opener. -->
<a href="https://github.com/jy741" class="github-corner tooltipped hide-on-med-and-down" target="_blank"
   rel="noopener noreferrer" aria-label="Fork Me on GitHub"
   data-tooltip="Fork Me" data-position="left" data-delay="50">
    <svg viewBox="0 0 250 250" aria-hidden="true">
        <path d="M0,0 L115,115 L130,115 L142,142 L250,250 L250,0 Z"></path>
        <path d="M128.3,109.0 C113.8,99.7 119.0,89.6 119.0,89.6 C122.0,82.7 120.5,78.6 120.5,78.6 C119.2,72.0 123.4,76.3 123.4,76.3 C127.3,80.9 125.5,87.3 125.5,87.3 C122.9,97.6 130.6,101.9 134.4,103.2"
              fill="currentColor" style="transform-origin: 130px 106px;" class="octo-arm"></path>
        <path d="M115.0,115.0 C114.9,115.1 118.7,116.5 119.8,115.4 L133.7,101.6 C136.9,99.2 139.9,98.4 142.2,98.6 C133.8,88.0 127.5,74.4 143.8,58.0 C148.5,53.4 154.0,51.2 159.7,51.0 C160.3,49.4 163.2,43.6 171.4,40.1 C171.4,40.1 176.1,42.5 178.8,56.2 C183.1,58.6 187.2,61.8 190.9,65.4 C194.5,69.0 197.7,73.2 200.1,77.6 C213.8,80.2 216.3,84.9 216.3,84.9 C212.7,93.1 206.9,96.0 205.4,96.6 C205.1,102.4 203.0,107.8 198.3,112.5 C181.9,128.9 168.3,122.5 157.7,114.1 C157.9,116.9 156.7,120.9 152.7,124.9 L141.0,136.5 C139.8,137.7 141.6,141.9 141.8,141.8 Z"
              fill="currentColor" class="octo-body"></path>
    </svg>
</a>
        
    </nav>

</header>

    



<!-- Post banner: a cover image set as an inline background, with the post title as the h1 on top of it. -->
<div class="bg-cover pd-header post-cover" style="background-image: url('/medias/featureimages/6.jpg')">
    <!-- Inline left/right offsets of 0 on the container; presumably they override theme defaults (confirm against matery.css). -->
    <div class="container" style="right: 0px;left: 0px;">
        <div class="row">
            <div class="col s12 m12 l12">
                <div class="brand">
                    <h1 class="description center-align post-title">Pytorch入门</h1>
                </div>
            </div>
        </div>
    </div>
</div>




<main class="post-container content">

    
    <link rel="stylesheet" href="/libs/tocbot/tocbot.css">
<style>
    /* Styles for the article body anchors, the table-of-contents sidebar and its floating toggle button. */

    /* Invisible 100px spacer above each heading so an in-page jump lands the heading
       below the top of the viewport (presumably to clear the fixed navbar). */
    #articleContent h1::before,
    #articleContent h2::before,
    #articleContent h3::before,
    #articleContent h4::before,
    #articleContent h5::before,
    #articleContent h6::before {
        display: block;
        content: " ";
        height: 100px;
        margin-top: -100px;
        visibility: hidden;
    }

    /* Suppress the focus ring only for focus that is not keyboard-driven.
       Keyboard users keep the browser's default :focus-visible indicator. */
    #articleContent :focus:not(:focus-visible) {
        outline: none;
    }

    /* Applied when the TOC should stay fixed below the top of the viewport. */
    .toc-fixed {
        position: fixed;
        top: 64px;
    }

    .toc-widget {
        width: 345px;
        padding-left: 20px;
    }

    .toc-widget .toc-title {
        padding: 35px 0 15px 17px;
        font-size: 1.5rem;
        font-weight: bold;
        line-height: 1.5rem;
    }

    /* Strip the default list numbering and padding from the TOC lists. */
    .toc-widget ol {
        padding: 0;
        list-style: none;
    }

    /* The TOC body scrolls on its own when it overflows. */
    #toc-content {
        padding-bottom: 30px;
        overflow: auto;
    }

    /* Indent nested TOC levels. */
    #toc-content ol {
        padding-left: 10px;
    }

    #toc-content ol li {
        padding-left: 10px;
    }

    #toc-content .toc-link:hover {
        color: #42b983;
        font-weight: 700;
        text-decoration: underline;
    }

    /* Transparent background and repositioning for the ::before marker on TOC links. */
    #toc-content .toc-link::before {
        background-color: transparent;
        max-height: 25px;

        position: absolute;
        right: 23.5vw;
        display: block;
    }

    /* Highlight colour for the entry carrying the is-active-link class. */
    #toc-content .is-active-link {
        color: #42b983;
    }

    /* Floating TOC toggle button pinned to the bottom-right of the viewport. */
    #floating-toc-btn {
        position: fixed;
        right: 15px;
        bottom: 76px;
        padding-top: 15px;
        margin-bottom: 0;
        z-index: 998;
    }

    #floating-toc-btn .btn-floating {
        width: 48px;
        height: 48px;
    }

    /* line-height matches the 48px button height so the icon is vertically centred. */
    #floating-toc-btn .btn-floating i {
        line-height: 48px;
        font-size: 1.4rem;
    }
</style>
<div class="row">
    <div id="main-content" class="col s12 m12 l9">
        <!-- 文章内容详情 -->
<div id="artDetail">
    <div class="card">
        <!-- Article metadata: tag chips, then publish date, word count and page-view counter. -->
        <div class="card-content article-info">
            <div class="row tag-cate">
                <div class="col s7">
                    <div class="article-tag">
                        <a href="/tags/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/">
                            <span class="chip bg-color">机器学习</span>
                        </a>
                        <a href="/tags/python/">
                            <span class="chip bg-color">python</span>
                        </a>
                    </div>
                </div>
                <div class="col s5 right-align">
                </div>
            </div>

            <div class="post-info">
                <div class="post-date info-break-policy">
                    <!-- Icons are decorative; the text label next to each one carries the meaning. -->
                    <i class="far fa-calendar-minus fa-fw" aria-hidden="true"></i>发布日期:&nbsp;&nbsp;
                    <!-- <time> exposes the publish date in machine-readable form. -->
                    <time datetime="2022-10-15">2022-10-15</time>
                </div>

                <div class="info-break-policy">
                    <i class="far fa-file-word fa-fw" aria-hidden="true"></i>文章字数:&nbsp;&nbsp;
                    2.9k
                </div>

                <!-- The span is empty in the markup; presumably the busuanzi counter script fills it at runtime. -->
                <div id="busuanzi_container_page_pv" class="info-break-policy">
                    <i class="far fa-eye fa-fw" aria-hidden="true"></i>阅读次数:&nbsp;&nbsp;
                    <span id="busuanzi_value_page_pv"></span>
                </div>
            </div>
        </div>
        <hr class="clearfix">

        
        <!-- 是否加载使用自带的 prismjs. -->
        <link rel="stylesheet" href="/libs/prism/prism.css">
        

        

        <div class="card-content article-card-content">
            <div id="articleContent">
<p>本文是基于B站up主<a target="_blank" rel="noopener" href="https://space.bilibili.com/203989554">我是土堆</a>的视频教程整理的学习笔记，如有侵权请联系作者。</p>
<span id="more"></span>
<h1 id="Pytorch入门实战"><a href="#Pytorch入门实战" class="headerlink" title="Pytorch入门实战"></a>Pytorch入门实战</h1><p>Pytorch官网：<a target="_blank" rel="noopener" href="https://pytorch.org/">https://pytorch.org/</a>    </p>
<p>以下介绍的实际应用在官网的Docs下都有明确的教程文档</p>
<h2 id="DataSet"><a href="#DataSet" class="headerlink" title="DataSet"></a>DataSet</h2><pre class="line-numbers language-python" data-language="python"><code class="language-python">from torch.utils.data import Dataset<span aria-hidden="true" class="line-numbers-rows"><span></span></span></code></pre>
<p>Dataset顾名思义是用来自定义数据集的，我们可以把它理解为Java中的抽象类：需要自定义一个数据集class来继承它，并实现它的<code>__init__()</code>、<code>__getitem__()</code>、<code>__len__()</code>三个方法，具体实例如下：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class MyData(Dataset):
    def __init__(self,root_path,label_path):
        self.root_path = root_path
        self.label_path = label_path
        self.path = os.path.join(self.root_path,self.label_path)
        self.img_list = os.listdir(self.path)


    def __getitem__(self, idx):
        img_path = os.path.join(self.root_path,self.label_path,self.img_list[idx])
        image = Image.open(img_path)
        label = self.label_path
        return image,label

    def __len__(self):
        return len(self.img_list)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>接下来我们只需要将我们的数据集复制到对应目录就可以使用了</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">root_path = "data/hymenoptera_data/train"
ants_dataset = MyData(root_path,"ants")
image,label = ants_dataset[0]
image.show()
print(label)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="Tensorboard"><a href="#Tensorboard" class="headerlink" title="Tensorboard"></a>Tensorboard</h2><p>Tensorboard可以对Pytorch训练过程进行可视化展示、主要用到SummaryWriter类</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">writter = SummaryWriter("logs")

image_path = "data/hymenoptera_data/train/ants/5650366_e22b7e1065.jpg"

#add_scalar:参数1：name	参数2：y轴	参数3：x轴
for i in range(100):
    writter.add_scalar("y=2x",2*i,i)


image_PIL = Image.open(image_path)
image_array = np.array(image_PIL)

#参数1：name	参数2：image(需要numpy类型或者tensor类型)	参数3：step
writter.add_image("test",image_array,1,dataformats='HWC')


writter.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<blockquote>
<p>运行完代码后，在控制台输入 <code>tensorboard --logdir=logs</code> 启动可视化服务器（默认6006端口）</p>
</blockquote>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015165634684.png" alt="image-20221015165634684"></p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015165655253.png" alt="image-20221015165655253"></p>
<p>SummaryWriter除了能可视化线型图和图片外，还能将模型结构通过图化展示</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )


    def forward(self,x):
        x = self.model(x)
        return x

tudui = Tudui()
input = torch.ones(64,3,32,32)
output = tudui(input)
print(output.shape)

writer = SummaryWriter("logs")
writer.add_graph(tudui,input)
writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015170530806.png" alt="image-20221015170530806"></p>
<h2 id="Transforms"><a href="#Transforms" class="headerlink" title="Transforms"></a>Transforms</h2><p>Transforms是Pytorch常用的图像预处理方法，一般用于转化图片类型、调整大小、归一化、裁剪图像等等，具体应用如下：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">img_path = "images/001.png"

image = Image.open(img_path)

# ToTensor
tensor_totensor = transforms.ToTensor()
img_tensor = tensor_totensor(image)

writer = SummaryWriter("logs")
writer.add_image("tf_test",img_tensor,1)


# Normalize——归一化
tensor_norm = transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
img_norm = tensor_norm(img_tensor)
writer.add_image("tf_test",img_norm,2)


# Resize
tensor_resize = transforms.Resize(512)
img_resize = tensor_resize(image)    #默认参数PIL图片格式，输出也是PIL格式
# print(type(img_resize))
img_resize = tensor_totensor(img_resize)
writer.add_image("Resize",img_resize,0)

# Compose——将多个处理方法连结在一起
tensor_compose = transforms.Compose([tensor_resize,tensor_totensor])
img_resize2 = tensor_compose(image)
writer.add_image("Resize",img_resize2,1)

# RandomCrop——随机裁剪
tensor_random = transforms.RandomCrop(500)
tensor_compose2 = transforms.Compose([tensor_random,tensor_totensor])
for i in range(10):
    img_random = tensor_compose2(image)
    writer.add_image("Random",img_random,i)


writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>我们一般用Transforms来实现图像由PIL类型转化为Tensor数据类型，因为在Pytorch中Tensor往往是后续很多方法参数所需要的类型</p>
<h2 id="DataLoader"><a href="#DataLoader" class="headerlink" title="DataLoader"></a>DataLoader</h2><p>介绍DataLoader之前我们简单介绍一下如何使用官方给的数据集（torchvision.datasets)</p>
<p>Torchvision 在<code>torchvision.datasets</code>模块中提供了许多内置数据集，以及用于构建您自己的数据集的实用程序类。</p>
<p>这些内置数据集都继承了上述的Dataset类，不需要我们手动去为它们实现方法</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_set = torchvision.datasets.CIFAR10(root="./data",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./data",train=False,transform=dataset_transform,download=True)

print(train_set.classes)

writer = SummaryWriter("logs")
for i in range(10):
    image,target = train_set[i]
    writer.add_image("dataset",image,i)

writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>而DataLoader其实就是一个迭代器，方便我们从数据集中一次读取多个数据</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">test_dataset = torchvision.datasets.CIFAR10(root="./data",train=False,transform=torchvision.transforms.ToTensor())

#参数解读：batch_size——一次加载多少个数据	shuffle——下次迭代是否打乱顺序	drop_last——表示是否删除非完整页的结尾数据
test_loader = DataLoader(dataset=test_dataset,batch_size=64,shuffle=True,drop_last=True)

writer = SummaryWriter("logs")

for each in range(2):
    step = 0
    for data in test_loader:
        imgs,targets = data
        writer.add_images("dataloader_{}".format(each),imgs,step)
        step+=1

writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="nn-Module"><a href="#nn-Module" class="headerlink" title="nn.Module"></a>nn.Module</h2><p>nn.Module是Pytorch中神经网络的基本骨架，如果我们想自定义一个神经网络模型类，我们必须要继承nn.Module这个类，并在这个类的<code>__init__</code>方法中调用父类的<code>__init__</code>方法，重写forward()方法（就是input进去经过怎样的转换输出成output），其形参就是模型（块）的输入。</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class Demo(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self,input):
        output = input+1
        return output


demo = Demo()
x = torch.tensor(1)
y = demo(x)
print(y)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="function的卷积操作"><a href="#function的卷积操作" class="headerlink" title="function的卷积操作"></a>function的卷积操作</h2><p>卷积操作主要就是用卷积核(kernel)对原始数据进行矩阵计算，得到一个新的输出，具体过程可以参照我《机器学习——Convolution》这篇博客</p>
<p>参数解读：input是输入的数据，kernel表示卷积核，stride表示步长，padding表示周围填充几层</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import torch.nn.functional as F


input = torch.tensor([[1,2,0,3,1],
                      [0,1,2,3,1],
                      [1,2,1,0,0],
                      [5,2,3,1,1],
                      [2,1,0,1,1]])


kernel = torch.tensor([[1,2,1],
                       [0,1,0],
                       [2,1,0]])

#F要求input的shape：batch_size,channel,H,W
input = torch.reshape(input,(1,1,5,5))
kernel = torch.reshape(kernel,(1,1,3,3))


output1 = F.conv2d(input,kernel,stride=1)
print(output1)

output2 = F.conv2d(input,kernel,stride=1,padding=1)
print(output2)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="卷积层"><a href="#卷积层" class="headerlink" title="卷积层"></a>卷积层</h2><p>其实就是对nn.functional的进一步封装，如nn.Conv2d()，最常用的是这五个参数：in_channels、out_channels、kernel_size、stride、padding</p>
<p>以下示例就是把CIFAR10数据集中的图片经过一层卷积后的结果</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">dataset = torchvision.datasets.CIFAR10("./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)

dataloader = DataLoader(dataset,64,drop_last=True)

writer = SummaryWriter("logs")

class Tudui(torch.nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)

    def forward(self,x):
        x = self.conv1(x)
        return x

step = 0
tudui = Tudui()

for data in dataloader:
    imgs,labels = data
    writer.add_images("input",imgs,step)
    output = tudui(imgs)
    output = torch.reshape(output,(-1,3,30,30))
    writer.add_images("output",output,step)
    step+=1

writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015174937563.png" alt="image-20221015174937563"></p>
<p><img src="https://cdn.jsdelivr.net/gh/jy741/media@main/img/image-20221015174948827.png" alt="image-20221015174948827"></p>
<h2 id="最大池化"><a href="#最大池化" class="headerlink" title="最大池化"></a>最大池化</h2><p>Maxpooling作用（maxpool2d举例）：</p>
<ul>
<li>一是对卷积层所提取的信息做进一步降维，减少计算量</li>
<li>二是加强图像特征的不变性，使之增加图像的偏移、旋转等方面的鲁棒性</li>
<li>类似于观看视频时不同的清晰度，实际效果就像给图片打马赛克</li>
</ul>
<p>示例：和上面代码差距不大，只是在模型中修改一下</p>
<pre class="line-numbers language-none"><code class="language-none">class Jy(nn.Module):
    def __init__(self):
        super(Jy, self).__init__()
        self.maxpool = MaxPool2d(kernel_size=3,ceil_mode=False)

    def forward(self,x):
        x = self.maxpool(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015175514221.png" alt="image-20221015175514221"></p>
<h2 id="非线性激活"><a href="#非线性激活" class="headerlink" title="非线性激活"></a>非线性激活</h2><p>非线性变换的主要目的就是给网络中加入一些非线性特征，非线性越多才能训练出符合各种特征的模型。常见的非线性激活：Sigmoid、ReLU</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.activation = Sigmoid()

    def forward(self,input):
        output = self.activation(input)
        return output<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015175741732.png" alt="image-20221015175741732"></p>
<h2 id="线性层"><a href="#线性层" class="headerlink" title="线性层"></a>线性层</h2><p>线性层其实就是全连接层，也就是CNN过程中的最后一步</p>
<p>线性函数为：torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)，其中重要的3个参数in_features、out_features、bias说明如下：</p>
<ul>
<li>in_features：每个输入（x）样本的特征的大小</li>
<li>out_features：每个输出（y）样本的特征的大小</li>
<li>bias：如果设置为False，则图层不会学习附加偏差。默认值是True，表示增加学习偏置。</li>
</ul>
<blockquote>
<p>进线性层之前我们需要把数据Flatten铺平，变成一维数据，而Linear作用就是缩小一维的数据长度</p>
</blockquote>
<pre class="line-numbers language-none"><code class="language-none">dataset = torchvision.datasets.CIFAR10("./data",False,transform=torchvision.transforms.ToTensor())

dataloader = DataLoader(dataset,64,drop_last=True)


class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.liner = Linear(196608,10)

    def forward(self,input):
        output = self.liner(input)
        return output

tudui = Tudui()

for data in dataloader:
    imgs,targets = data
    print(imgs.shape)
    output = torch.flatten(imgs)
    print(output.shape)
    output = tudui(output)
    print(output.shape)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="Sequential的使用"><a href="#Sequential的使用" class="headerlink" title="Sequential的使用"></a>Sequential的使用</h2><p>在我们定义模型时，往往需要经过一个复杂的神经网络结构才能得到最终的输出，在引入Sequential之前，我们来看看CIFAR10模型应该如何定义</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015180345237.png" alt="image-20221015180345237"></p>
<p>这里我们需要自己计算一下padding值是多少，我们根据官方文档给出的公式进行计算得到padding = 2</p>
<p><img src="https://pluto-1300780100.cos.ap-nanjing.myqcloud.com/img/image-20221015201850741.png" alt="image-20221015201850741"></p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.conv1 = Conv2d(3,32,5,padding=2)
        self.maxpool1 = MaxPool2d(2)
        self.conv2 = Conv2d(32,32,5,padding=2)
        self.maxpool2 = MaxPool2d(2)
        self.conv3 = Conv2d(32,64,5,padding=2)
        self.maxpool3 = MaxPool2d(2)
        self.flatten = Flatten()
        self.linear1 = Linear(1024,64)
        self.linear2 = Linear(64,10)
        


    def forward(self,x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.conv3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.linear1(x)
        x = self.linear2(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>而引入Sequential后我们可以大大简化代码的书写：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )


    def forward(self,x):
        x = self.model(x)
        return x<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="损失函数与优化器"><a href="#损失函数与优化器" class="headerlink" title="损失函数与优化器"></a>损失函数与优化器</h2><p>损失函数（Loss）：有多种计算方式</p>
<ul>
<li><p>计算实际输出和目标之间的差距</p>
</li>
<li><p>为我们更新输出 提供一定的依据</p>
<p>更新输出：</p>
<ul>
<li>反向传播（backward）–&gt;</li>
<li>计算出梯度（grad）–&gt;</li>
<li>根据梯度和学习率来更新参数–&gt;</li>
<li>减小loss</li>
</ul>
</li>
</ul>
<p>而更新参数需要我们的优化器来做，在优化器中我们需要定义一个学习率Learning_rate，以下为一个简单优化实例：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">tudui = Tudui()

loss = nn.CrossEntropyLoss()

optim = torch.optim.SGD(tudui.parameters(),lr=0.01)

for epch in range(10):
    sum_loss = 0
    for data in dataloader:
        imgs,targets = data
        output = tudui(imgs)
        result_loss = loss(output,targets)
        optim.zero_grad()	#梯度置0
        result_loss.backward()	#计算梯度
        optim.step()	#根据梯度更新参数
        sum_loss += result_loss
    print(sum_loss)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="现有的网络模型"><a href="#现有的网络模型" class="headerlink" title="现有的网络模型"></a>现有的网络模型</h2><p>我们可以直接加载torchvision提供的模型，其中pretrained参数为False表示使用随机初始化的网络参数，为True表示使用在ImageNet数据集上预训练好的参数</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)

#在原有的模型上新增一个线性层
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))

print(vgg16_true)<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<h2 id="完整过程实例"><a href="#完整过程实例" class="headerlink" title="完整过程实例"></a>完整过程实例</h2><p>训练实例：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import torch
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

#定义数据集
train_dataset = torchvision.datasets.CIFAR10("../data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.CIFAR10("../data",train=False,transform=torchvision.transforms.ToTensor(),download=True)

#数据集长度
train_dataset_len = len(train_dataset)
test_dataset_len = len(test_dataset)
print("训练集长度：{}".format(train_dataset_len))
print("测试集长度：{}".format(test_dataset_len))

#定义dataloader
train_dataloader = DataLoader(train_dataset,64)
test_dataloader = DataLoader(test_dataset,64)

#中间结果绘图
writer = SummaryWriter("practice_log")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#定义神经网络
tudui = Tudui()
tudui = tudui.to(device)

#定义训练/测试步数
train_step = 0
test_step = 0

#定义loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

#定义优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)

#定义训练次数
train_number = 10

for i in range(train_number):
    print("第{}轮训练开始".format(i+1))


    #训练步骤
    tudui.train() #模型进入训练模式
    for data in train_dataloader:
        imgs,targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = tudui(imgs)
        train_loss = loss_fn(outputs,targets)
        #优化模型
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        train_step += 1
        if(train_step % 100 == 0):
            print("训练次数:{}  Loss:{}".format(train_step,train_loss))
            writer.add_scalar("train_loss",train_loss.item(),train_step)

    #测试步骤
    tudui.eval()    #模型进入验证模式
    test_sum_loss = 0
    test_accuray = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = tudui(imgs)
            test_loss = loss_fn(outputs,targets)
            test_sum_loss += test_loss
            #argmax()函数：返回指定维度上最大值所在的索引    参数：1——按行取最大 0——按列取最大
            accuray = (outputs.argmax(1) == targets).sum()
            test_accuray += accuray

        print("整体测试集上的Loss:{}".format(test_sum_loss))
        print("整体测试集上的正确率:{}".format(test_accuray/test_dataset_len))
        writer.add_scalar("test_loss",test_sum_loss,test_step)
        writer.add_scalar("test_accuray",test_accuray/test_dataset_len,test_step)

        test_step+=1


    torch.save(tudui,"tudui_{}.pth".format(i))
    print("模型已保存")

writer.close()<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>
<p>验证实例：</p>
<pre class="line-numbers language-python" data-language="python"><code class="language-python">import torch
import torchvision
from PIL import Image
from torch import nn

img_path = "./frog_1.png"
image = Image.open(img_path)

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),
                                            torchvision.transforms.ToTensor()])

image = transform(image)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )

    def forward(self,x):
        x = self.model(x)
        return x


model = torch.load("tudui_29.pth",map_location=torch.device('cpu'))

image = torch.reshape(image,(1,3,32,32))

model.eval()

with torch.no_grad():
    output = model(image)

print(output.argmax(1).item())<span aria-hidden="true" class="line-numbers-rows"><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span><span></span></span></code></pre>

                
            </div>
            <hr/>

            

    <div class="reprint" id="reprint-statement">
        
            <div class="reprint__author">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-user">
                        文章作者:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="/about" rel="external nofollow noreferrer">J Sir</a>
                </span>
            </div>
            <div class="reprint__type">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-link">
                        文章链接:
                    </i>
                </span>
                <span class="reprint-info">
                    <a href="https://jy741.gitee.io/2022/10/15/pytorch-ru-men/">https://jy741.gitee.io/2022/10/15/pytorch-ru-men/</a>
                </span>
            </div>
            <div class="reprint__notice">
                <span class="reprint-meta" style="font-weight: bold;">
                    <i class="fas fa-copyright">
                        版权声明:
                    </i>
                </span>
                <span class="reprint-info">
                    本博客所有文章除特別声明外，均采用
                    <a href="https://creativecommons.org/licenses/by/4.0/deed.zh" rel="external nofollow noreferrer" target="_blank">CC BY 4.0</a>
                    许可协议。转载请注明来源
                    <a href="/about" target="_blank">J Sir</a>
                    !
                </span>
            </div>
        
    </div>

    <script async defer>
      // Whenever the visitor copies something from the page, show a toast that
      // reminds them of the reprint rules and offers a button to jump to them.
      document.addEventListener("copy", function () {
        // The action element is a <button>, so it must be closed with </button>
        // (it was previously closed with a stray </a>, producing invalid markup).
        let toastHTML = '<span>复制成功，请遵循本文的转载规则</span><button class="btn-flat toast-action" onclick="navToReprintStatement()" style="font-size: smaller">查看</button>';
        M.toast({html: toastHTML})
      });

      /**
       * Smoothly scroll the page to the reprint statement box.
       * Stops 80px above the element's top edge; the animation lasts 800ms.
       */
      function navToReprintStatement() {
        $("html, body").animate({scrollTop: $("#reprint-statement").offset().top - 80}, 800);
      }
    </script>



            <div class="tag_share" style="display: block;">
                <div class="post-meta__tag-list" style="display: inline-block;">
                    
                        <div class="article-tag">
                            
                                <a href="/tags/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0/">
                                    <span class="chip bg-color">机器学习</span>
                                </a>
                            
                                <a href="/tags/python/">
                                    <span class="chip bg-color">python</span>
                                </a>
                            
                        </div>
                    
                </div>
                <div class="post_share" style="zoom: 80%; width: fit-content; display: inline-block; float: right; margin: -0.15rem 0;">
                    <link rel="stylesheet" type="text/css" href="/libs/share/css/share.min.css">
<div id="article-share">

    
    <div class="social-share" data-sites="twitter,facebook,google,qq,qzone,wechat,weibo,douban,linkedin" data-wechat-qrcode-helper="<p>微信扫一扫即可分享！</p>"></div>
    <script src="/libs/share/js/social-share.min.js"></script>
    

    

</div>

                </div>
            </div>
            
                <style>
    /* Wrapper around the reward trigger button: centered, with vertical spacing. */
    #reward {
        margin: 40px 0;
        text-align: center;
    }

    /* Text inside the round trigger button. */
    #reward .reward-link {
        font-size: 1.4rem;
        line-height: 38px;
    }

    /* Stronger shadow while the floating button is hovered. */
    #reward .btn-floating:hover {
        box-shadow: 0 6px 12px rgba(0, 0, 0, 0.2), 0 5px 15px rgba(0, 0, 0, 0.2);
    }

    /* Fixed-size modal that holds the payment QR codes. */
    #rewardModal {
        width: 320px;
        height: 350px;
    }

    #rewardModal .reward-title {
        margin: 15px auto;
        padding-bottom: 5px;
    }

    #rewardModal .modal-content {
        padding: 10px;
    }

    /* Close icon pinned to the modal's top-right corner. */
    #rewardModal .close {
        position: absolute;
        right: 15px;
        top: 15px;
        color: rgba(0, 0, 0, 0.5);
        font-size: 1.3rem;
        line-height: 20px;
        cursor: pointer;
    }

    /* Hover feedback for the close icon: turns red and grows by 30%. */
    #rewardModal .close:hover {
        color: #ef5350;
        transform: scale(1.3);
        -moz-transform:scale(1.3);
        -webkit-transform:scale(1.3);
        -o-transform:scale(1.3);
    }

    /* Tab area; its 210px width matches the QR image width below. */
    #rewardModal .reward-tabs {
        margin: 0 auto;
        width: 210px;
    }

    .reward-tabs .tabs {
        height: 38px;
        margin: 10px auto;
        padding-left: 0;
    }

    /* Remove list indentation inside the reward content (forced with !important). */
    .reward-content ul {
        padding-left: 0 !important;
    }

    .reward-tabs .tabs .tab {
        height: 38px;
        line-height: 38px;
    }

    /* Default tab look: white text on grey, unchanged on hover. */
    .reward-tabs .tab a {
        color: #fff;
        background-color: #ccc;
    }

    .reward-tabs .tab a:hover {
        background-color: #ccc;
        color: #fff;
    }

    /* Active WeChat tab: green background. */
    .reward-tabs .wechat-tab .active {
        color: #fff !important;
        background-color: #22AB38 !important;
    }

    /* Active Alipay tab: blue background. */
    .reward-tabs .alipay-tab .active {
        color: #fff !important;
        background-color: #019FE8 !important;
    }

    /* QR code images are rendered as 210px squares. */
    .reward-tabs .reward-img {
        width: 210px;
        height: 210px;
    }
</style>

<div id="reward">
    <a href="#rewardModal" class="reward-link modal-trigger btn-floating btn-medium waves-effect waves-light red">赏</a>

    <!-- Modal Structure -->
    <div id="rewardModal" class="modal">
        <div class="modal-content">
            <a class="close modal-close"><i class="fas fa-times"></i></a>
            <h4 class="reward-title">你的赏识是我前进的动力</h4>
            <div class="reward-content">
                <div class="reward-tabs">
                    <ul class="tabs row">
                        <li class="tab col s6 alipay-tab waves-effect waves-light"><a href="#alipay">支付宝</a></li>
                        <li class="tab col s6 wechat-tab waves-effect waves-light"><a href="#wechat">微 信</a></li>
                    </ul>
                    <div id="alipay">
                        <img src="/medias/reward/alipay.jpg" class="reward-img" alt="支付宝打赏二维码">
                    </div>
                    <div id="wechat">
                        <img src="/medias/reward/wechat.png" class="reward-img" alt="微信打赏二维码">
                    </div>
                </div>
            </div>
        </div>
    </div>
</div>

<script>
    // When the DOM is ready, run the `.tabs()` jQuery plugin on every
    // element with class `tabs` (the reward modal's tab strip above).
    $(() => {
        const $tabStrips = $('.tabs');
        $tabStrips.tabs();
    });
</script>

            
        </div>
    </div>

    

    

    

    

    

    

    

    

<article id="prenext-posts" class="prev-next articles">
    <div class="row article-row">
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge left-badge text-color">
                <i class="fas fa-chevron-left"></i>&nbsp;上一篇</div>
            <div class="card">
                <a href="/2022/10/19/shen-du-xue-xi-tu-xiang-fen-lei-ru-men-parti/">
                    <div class="card-image">
                        
                        
                        <img src="/medias/featureimages/11.jpg" class="responsive-img" alt="深度学习——图像分类入门PartⅠ">
                        
                        <span class="card-title">深度学习——图像分类入门PartⅠ</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            EverydayOneCat⬆️??!??!⬆️ 😯😯⁉️ 🐯⬇️

                        
                    </div>
                    <div class="publish-info">
                        <span class="publish-date">
                            <i class="far fa-clock fa-fw icon-date"></i>2022-10-19
                        </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-user fa-fw"></i>
                            J Sir
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/python/">
                        <span class="chip bg-color">python</span>
                    </a>
                    
                    <a href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">
                        <span class="chip bg-color">深度学习</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
        
        <div class="article col s12 m6" data-aos="fade-up">
            <div class="article-badge right-badge text-color">
                下一篇&nbsp;<i class="fas fa-chevron-right"></i>
            </div>
            <div class="card">
                <a href="/2022/10/11/shen-du-xue-xi-backpropagation/">
                    <div class="card-image">
                        
                        
                        <img src="/medias/featureimages/6.jpg" class="responsive-img" alt="深度学习——Backpropagation">
                        
                        <span class="card-title">深度学习——Backpropagation</span>
                    </div>
                </a>
                <div class="card-content article-content">
                    <div class="summary block-with-text">
                        
                            本文基于Sakura-gh大佬的机器学习笔记修改，仅作为学习资料备用，如有侵权，联系作者。
                        
                    </div>
                    <div class="publish-info">
                            <span class="publish-date">
                                <i class="far fa-clock fa-fw icon-date"></i>2022-10-11
                            </span>
                        <span class="publish-author">
                            
                            <i class="fas fa-user fa-fw"></i>
                            J Sir
                            
                        </span>
                    </div>
                </div>
                
                <div class="card-action article-tags">
                    
                    <a href="/tags/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/">
                        <span class="chip bg-color">深度学习</span>
                    </a>
                    
                </div>
                
            </div>
        </div>
        
    </div>
</article>

</div>



<!-- 代码块功能依赖 -->
<script type="text/javascript" src="/libs/codeBlock/codeBlockFuction.js"></script>

<!-- 代码语言 -->

<script type="text/javascript" src="/libs/codeBlock/codeLang.js"></script>


<!-- 代码块复制 -->

<script type="text/javascript" src="/libs/codeBlock/codeCopy.js"></script>


<!-- 代码块收缩 -->

<script type="text/javascript" src="/libs/codeBlock/codeShrink.js"></script>


    </div>
    <div id="toc-aside" class="expanded col l3 hide-on-med-and-down">
        <div class="toc-widget card" style="background-color: white;">
            <div class="toc-title"><i class="far fa-list-alt"></i>&nbsp;&nbsp;目录</div>
            <div id="toc-content"></div>
        </div>
    </div>
</div>

<!-- TOC 悬浮按钮. -->

<div id="floating-toc-btn" class="hide-on-med-and-down">
    <a class="btn-floating btn-large bg-color">
        <i class="fas fa-list-ul"></i>
    </a>
</div>


<script src="/libs/tocbot/tocbot.min.js"></script>
<script>
    // DOM-ready setup for the article's table of contents (TOC):
    //   1. build the TOC with tocbot,
    //   2. renumber TOC links and heading ids to ASCII-only anchors,
    //   3. pin the TOC widget once the page is scrolled far enough,
    //   4. wire the floating button that shows/hides the TOC column.
    $(function () {
        // Generate the TOC inside #toc-content from the headings in #articleContent.
        tocbot.init({
            tocSelector: '#toc-content',
            contentSelector: '#articleContent',
            headingsOffset: -($(window).height() * 0.4 - 45),
            collapseDepth: Number('0'),
            headingSelector: 'h1, h2, h3, h4, h5'
        });

        // Rewrite every TOC link to "#toc-heading-N" (N counts from 1) so the
        // anchors do not depend on the Chinese heading text.
        let i = 0;
        let tocHeading = 'toc-heading-';
        $('#toc-content a').each(function () {
            $(this).attr('href', '#' + tocHeading + (++i));
        });

        // Give the headings the matching "toc-heading-N" ids, restarting at 1.
        // NOTE(review): only direct children of #articleContent are renumbered,
        // while every TOC link is renumbered above — confirm the two sequences
        // line up when a heading is nested inside another element.
        i = 0;
        $('#articleContent').children('h1, h2, h3, h4, h5').each(function () {
            $(this).attr('id', tocHeading + (++i));
        });

        // Pin the TOC widget while the page is scrolled past the threshold.
        // The threshold is computed once from the window height at load time.
        let tocHeight = parseInt($(window).height() * 0.4 - 64);
        let $tocWidget = $('.toc-widget');
        $(window).scroll(function () {
            let scroll = $(window).scrollTop();
            // Toggle the 'toc-fixed' class depending on the scroll position.
            if (scroll > tocHeight) {
                $tocWidget.addClass('toc-fixed');
            } else {
                $tocWidget.removeClass('toc-fixed');
            }
        });

        
        // Fix the width of the post-card container: set the width of element
        // #targetId to the width of element #srcId plus a padding that depends
        // on the source width. Does nothing when #srcId is not found.
        let fixPostCardWidth = function (srcId, targetId) {
            let srcDiv = $('#' + srcId);
            if (srcDiv.length === 0) {
                return;
            }

            let w = srcDiv.width();
            // Wider sources get a larger extra width (21 / 18 / 16 / 14 px).
            if (w >= 450) {
                w = w + 21;
            } else if (w >= 350 && w < 450) {
                w = w + 18;
            } else if (w >= 300 && w < 350) {
                w = w + 16;
            } else {
                w = w + 14;
            }
            $('#' + targetId).width(w);
        };

        // Expand/collapse the TOC column when the floating TOC button is clicked.
        const expandedClass = 'expanded';
        let $tocAside = $('#toc-aside');
        let $mainContent = $('#main-content');
        $('#floating-toc-btn .btn-floating').click(function () {
            if ($tocAside.hasClass(expandedClass)) {
                // Collapse: hide the TOC column and drop the 'l9' class from the main content.
                $tocAside.removeClass(expandedClass).hide();
                $mainContent.removeClass('l9');
            } else {
                // Expand: show the TOC column again and restore the 'l9' class.
                $tocAside.addClass(expandedClass).show();
                $mainContent.addClass('l9');
            }
            // The content width just changed, so resize the prev/next cards to match.
            fixPostCardWidth('artDetail', 'prenext-posts');
        });
        
    });
</script>

    

</main>




    <footer class="page-footer bg-color">
    
        <link rel="stylesheet" href="/libs/aplayer/APlayer.min.css">
<style>
    /* Lyric lines of the player: hidden via display: none; the font rules
       below only matter if that declaration is removed. */
    .aplayer .aplayer-lrc p {
        
        display: none;
        
        font-size: 12px;
        font-weight: 700;
        line-height: 16px !important;
    }

    /* The currently highlighted lyric line is hidden as well. */
    .aplayer .aplayer-lrc p.aplayer-lrc-current {
        
        display: none;
        
        font-size: 15px;
        color: #42b983;
    }

    
    /* Fixed + narrow player: body is shifted 66px to the left ... */
    .aplayer.aplayer-fixed.aplayer-narrow .aplayer-body {
        left: -66px !important;
    }

    /* ... and moved back to left: 0 while hovered. */
    .aplayer.aplayer-fixed.aplayer-narrow .aplayer-body:hover {
        left: 0px !important;
    }

    
    
</style>
<div class="">
    
    <div class="row">
        <meting-js class="col l8 offset-l2 m10 offset-m1 s12"
                   server="netease"
                   type="playlist"
                   id="503838841"
                   fixed='true'
                   autoplay='false'
                   theme='#42b983'
                   loop='all'
                   order='random'
                   preload='auto'
                   volume='0.7'
                   list-folded='true'
        >
        </meting-js>
    </div>
</div>

<script src="/libs/aplayer/APlayer.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/meting@2/dist/Meting.min.js"></script>

    
    <div class="container row center-align" style="margin-bottom: 0px !important;">
        <div class="col s12 m8 l8 copy-right">
            Copyright&nbsp;&copy;
            
                <span id="year">2020-2023</span>
            
            <a href="/about" target="_blank">J Sir</a>
            |&nbsp;Powered by&nbsp;<a href="https://hexo.io/" target="_blank">Hexo</a>
            |&nbsp;Theme&nbsp;<a href="https://github.com/blinkfox/hexo-theme-matery" target="_blank">Matery</a>
            <br>
            
            &nbsp;<i class="fas fa-chart-area"></i>&nbsp;站点总字数:&nbsp;<span
                class="white-color">300.9k</span>&nbsp;字
            
            
            
            
            
            
            <span id="busuanzi_container_site_pv">
                |&nbsp;<i class="far fa-eye"></i>&nbsp;总访问量:&nbsp;<span id="busuanzi_value_site_pv"
                    class="white-color"></span>&nbsp;次
            </span>
            
            
            <span id="busuanzi_container_site_uv">
                |&nbsp;<i class="fas fa-users"></i>&nbsp;总访问人数:&nbsp;<span id="busuanzi_value_site_uv"
                    class="white-color"></span>&nbsp;人
            </span>
            
            <br>
            
            <br>
            
        </div>
        <div class="col s12 m4 l4 social-link social-statis">


    <a href="mailto:2065373132@qq.com" class="tooltipped" target="_blank" data-tooltip="邮件联系我" data-position="top" data-delay="50">
        <i class="fas fa-envelope-open"></i>
    </a>







    <a href="tencent://AddContact/?fromId=50&fromSubId=1&subcmd=all&uin=2065373132" class="tooltipped" target="_blank" data-tooltip="QQ联系我: 2065373132" data-position="top" data-delay="50">
        <i class="fab fa-qq"></i>
    </a>







    <a href="/atom.xml" class="tooltipped" target="_blank" data-tooltip="RSS 订阅" data-position="top" data-delay="50">
        <i class="fas fa-rss"></i>
    </a>

</div>
    </div>
</footer>

<div class="progress-bar"></div>


    <!-- 搜索遮罩框 -->
<div id="searchModal" class="modal">
    <div class="modal-content">
        <div class="search-header">
            <span class="title"><i class="fas fa-search"></i>&nbsp;&nbsp;搜索</span>
            <input type="search" id="searchInput" name="s" placeholder="请输入搜索的关键字"
                   class="search-input">
        </div>
        <div id="searchResult"></div>
    </div>
</div>

<script type="text/javascript">
$(function () {
    /**
     * Escape regular-expression metacharacters so user-typed text is matched
     * literally (e.g. "c++" must not be parsed as a quantifier).
     */
    var escapeRegExp = function (text) {
        return text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
    };

    /**
     * Client-side full-text search.
     *
     * Downloads the XML search index once, then re-renders the result list on
     * every `input` event of the search box.
     *
     * @param {string} path       URL of the XML index (entries with title/content/url).
     * @param {string} search_id  id of the search <input> element.
     * @param {string} content_id id of the element that receives the result list.
     */
    var searchFunc = function (path, search_id, content_id) {
        'use strict';
        $.ajax({
            url: path,
            dataType: "xml",
            success: function (xmlResponse) {
                // get the contents from search data
                var datas = $("entry", xmlResponse).map(function () {
                    return {
                        title: $("title", this).text(),
                        content: $("content", this).text(),
                        url: $("url", this).text()
                    };
                }).get();
                var $input = document.getElementById(search_id);
                var $resultContent = document.getElementById(content_id);
                $input.addEventListener('input', function () {
                    var str = '<ul class=\"search-result-list\">';
                    var query = this.value.trim();
                    $resultContent.innerHTML = "";
                    if (query.length <= 0) {
                        return;
                    }
                    // Drop empty fragments (e.g. from a leading "-"): an empty
                    // keyword matches everywhere and would wreck the highlighting.
                    var keywords = query.toLowerCase().split(/[\s\-]+/).filter(function (keyword) {
                        return keyword !== '';
                    });
                    if (keywords.length === 0) {
                        return;
                    }
                    // One combined, escaped pattern highlights all keywords in a
                    // single pass, so a later keyword can never match inside the
                    // <em> markup inserted for an earlier one. Longer keywords go
                    // first so the longest alternative wins at a given position.
                    var highlightPattern = new RegExp(
                        keywords.slice().sort(function (a, b) {
                            return b.length - a.length;
                        }).map(escapeRegExp).join('|'),
                        'gi'
                    );
                    // perform local searching
                    datas.forEach(function (data) {
                        var isMatch = true;
                        var data_title = data.title.trim().toLowerCase();
                        var data_content = data.content.trim().replace(/<[^>]+>/g, "").toLowerCase();
                        var data_url = data.url;
                        data_url = data_url.indexOf('/') === 0 ? data.url : '/' + data_url;
                        var index_title = -1;
                        var index_content = -1;
                        var first_occur = -1;
                        // only match articles with non-empty titles and contents
                        if (data_title !== '' && data_content !== '') {
                            keywords.forEach(function (keyword, i) {
                                index_title = data_title.indexOf(keyword);
                                index_content = data_content.indexOf(keyword);
                                if (index_title < 0 && index_content < 0) {
                                    isMatch = false;
                                } else {
                                    if (index_content < 0) {
                                        index_content = 0;
                                    }
                                    if (i === 0) {
                                        first_occur = index_content;
                                    }
                                }
                            });
                        } else {
                            // Entries without a title or content were previously
                            // reported as matches for every query.
                            isMatch = false;
                        }
                        // show search results
                        if (isMatch) {
                            str += "<li><a href='" + data_url + "' class='search-result-title'>" + data_title + "</a>";
                            var content = data.content.trim().replace(/<[^>]+>/g, "");
                            if (first_occur >= 0) {
                                // cut out (at most) 100 characters around the first hit
                                var start = first_occur - 20;
                                var end = first_occur + 80;
                                if (start < 0) {
                                    start = 0;
                                }
                                if (start === 0) {
                                    end = 100;
                                }
                                if (end > content.length) {
                                    end = content.length;
                                }
                                // `end` is an index, not a length: use substring(),
                                // not substr(), or the excerpt grows with `start`.
                                var match_content = content.substring(start, end);
                                // highlight all keywords, keeping the original casing of the hit
                                match_content = match_content.replace(highlightPattern, function (hit) {
                                    return "<em class=\"search-keyword\">" + hit + "</em>";
                                });

                                str += "<p class=\"search-result\">" + match_content + "...</p>"
                            }
                            str += "</li>";
                        }
                    });
                    str += "</ul>";
                    $resultContent.innerHTML = str;
                });
            }
        });
    };

    searchFunc('/search.xml', 'searchInput', 'searchResult');
});
</script>

    <!-- 回到顶部按钮 -->
<div id="backTop" class="top-scroll">
    <a class="btn-floating btn-large waves-effect waves-light" href="#!">
        <i class="fas fa-arrow-up"></i>
    </a>
</div>


    <script src="/libs/materialize/materialize.min.js"></script>
    <script src="/libs/masonry/masonry.pkgd.min.js"></script>
    <script src="/libs/aos/aos.js"></script>
    <script src="/libs/scrollprogress/scrollProgress.min.js"></script>
    <script src="/libs/lightGallery/js/lightgallery-all.min.js"></script>
    <script src="/js/matery.js"></script>

    <!-- Baidu Analytics -->

    <!-- Baidu Push -->

<script>
    // Baidu link-submit loader: picks the push script that matches the page's
    // protocol and inserts it in front of the first <script> element.
    (function () {
        var scheme = window.location.protocol.split(':')[0];
        var pushScript = document.createElement('script');
        pushScript.src = scheme === 'https'
            ? 'https://zz.bdstatic.com/linksubmit/push.js'
            : 'http://push.zhanzhang.baidu.com/push.js';
        var firstScript = document.getElementsByTagName("script")[0];
        firstScript.parentNode.insertBefore(pushScript, firstScript);
    })();
</script>

    
    <script src="/libs/others/clicklove.js" async="async"></script>
    
    
    <script async src="/libs/others/busuanzi.pure.mini.js"></script>
    

    

    

	
    

    

    

    
    <script src="/libs/instantpage/instantpage.js" type="module"></script>
    

<script src="/live2dw/lib/L2Dwidget.min.js?094cbace49a39548bed64abff5988b05"></script><script>L2Dwidget.init({"pluginRootPath":"live2dw/","pluginJsPath":"lib/","pluginModelPath":"assets/","tagMode":false,"debug":false,"model":{"jsonPath":"live2d-widget-model-hibiki"},"display":{"position":"right","width":145,"height":315},"mobile":{"show":true,"scale":0.5},"react":{"opacityDefault":0.7,"opacityOnHover":0.8},"log":false});</script></body>

</html>
